﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

namespace TreeDemo
{
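    /* 1. 树形结构数据关键属性：Children或ParentId 至少有一个，以此来连接父子关系
     * 2. 树形数据 关键属性：Children，则应用Flatten()时，需T:ITreeChildren<T, TKey>
     * 3. 平铺数据 关键属性：ParentId，则应用ToTree()时，需T:ITreeParentId<T, TKey>
     *    （基于 idSelector/pidSelector 的 ToTree 重载需同时实现 ITreeParentId<T, TKey> 与 ITreeChildren<T, TKey>）
     */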

    /// <summary>
    /// Base contract for tree nodes: exposes the node's identifier.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the name does not follow the usual <c>I</c> prefix convention for interfaces
    /// (e.g. <c>ITreeKey</c>); it is kept as-is because renaming would break every implementer.
    /// </remarks>
    /// <typeparam name="TKey">Type of the node identifier.</typeparam>
    public interface TTreeKey<TKey>
    {
        /// <summary>
        /// Identifier of the node; other nodes refer to it through their ParentId when a tree is built.
        /// </summary>
        TKey Id { get; set; }
    }

    #region 带 Children 的树形实体接口
    /// <summary>
    /// Tree entity that references its child nodes directly (nested / hierarchical data).
    /// </summary>
    /// <typeparam name="T">Element type of <see cref="Children"/>; presumably the implementing node type itself.</typeparam>
    /// <typeparam name="TKey">Type of the node identifier inherited from <see cref="TTreeKey{TKey}"/>.</typeparam>
    public interface ITreeChildren<T, TKey> : TTreeKey<TKey>
    {
        /// <summary>
        /// Child nodes of this node. May be <c>null</c>; TreeHelper.Flatten replaces a <c>null</c> value with an empty list.
        /// </summary>
        IEnumerable<T> Children { get; set; }
    }
    #endregion

    #region 带 ParentId 的树形实体接口
    /// <summary>
    /// Tree entity that references its parent by id (flat / adjacency-list data).
    /// </summary>
    /// <typeparam name="T">
    /// Presumably the implementing node type. No member uses it; it is kept so existing self-referencing
    /// constraints such as <c>where T : ITreeParentId&lt;T, TKey&gt;</c> keep compiling.
    /// </typeparam>
    /// <typeparam name="TKey">Type of the node identifier and of <see cref="ParentId"/>.</typeparam>
    public interface ITreeParentId<T, TKey> : TTreeKey<TKey>
    {
        /// <summary>
        /// Id of the parent node. Nodes whose ParentId equals the caller-chosen root value
        /// (or <c>null</c>, for reference-type keys) are treated as roots when building a tree.
        /// </summary>
        TKey ParentId { get; set; }
    }
    #endregion



    /// <summary>
    /// Helpers for converting between flat node lists and tree structures.
    /// </summary>
    public static class TreeHelper
    {
        #region Flatten (based on ITreeChildren)
        /// <summary>
        /// Flattens a tree linked through <see cref="ITreeChildren{T, TKey}.Children"/> into a pre-order
        /// sequence: every node is yielded before its descendants, and siblings keep their original order.
        /// </summary>
        /// <typeparam name="T">Node type.</typeparam>
        /// <typeparam name="TKey">Node identifier type.</typeparam>
        /// <param name="items">Top-level nodes to flatten.</param>
        /// <returns>A lazily evaluated sequence containing every node of the tree(s).</returns>
        /// <remarks>
        /// While enumerating, a node whose <c>Children</c> is <c>null</c> gets an empty list assigned.
        /// Children are walked with an explicit stack of enumerators instead of nested iterators, so
        /// surfacing an element does not cost O(depth). A cycle in <c>Children</c> never terminates.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="items"/> is <c>null</c>.</exception>
        public static IEnumerable<T> Flatten<T, TKey>(this IEnumerable<T> items) where T : ITreeChildren<T, TKey>
        {
            if (items == null) throw new ArgumentNullException(nameof(items));
            return FlattenIterator<T, TKey>(items);
        }

        /// <summary>Iterator behind <see cref="Flatten{T, TKey}"/>; a depth-first walk driven by a stack of enumerators.</summary>
        private static IEnumerable<T> FlattenIterator<T, TKey>(IEnumerable<T> items) where T : ITreeChildren<T, TKey>
        {
            var pending = new Stack<IEnumerator<T>>();
            pending.Push(items.GetEnumerator());
            try
            {
                while (pending.Count > 0)
                {
                    var level = pending.Peek();
                    if (!level.MoveNext())
                    {
                        // This sibling list is exhausted; resume with the parent level.
                        pending.Pop().Dispose();
                        continue;
                    }

                    var item = level.Current;
                    yield return item;
                    item.Children ??= new List<T>();
                    // Descend into the children before continuing with the next sibling (pre-order).
                    pending.Push(item.Children.GetEnumerator());
                }
            }
            finally
            {
                // Release any enumerators still open if the caller stops enumerating early.
                while (pending.Count > 0)
                {
                    pending.Pop().Dispose();
                }
            }
        }
        #endregion

        #region ToTree (based on ITreeParentId)
        /// <summary>
        /// Builds a tree from a flat list whose nodes point at their parent through
        /// <see cref="ITreeParentId{T, TKey}.ParentId"/>.
        /// </summary>
        /// <typeparam name="T">Node type.</typeparam>
        /// <typeparam name="TKey">Node identifier type.</typeparam>
        /// <param name="treeDataList">Flat node list; <c>null</c> entries are ignored.</param>
        /// <param name="parentId">ParentId value that marks a root node (defaults to <c>default(TKey)</c>).</param>
        /// <returns>The root nodes, in list order.</returns>
        /// <remarks>
        /// When a node (of reference type) also implements <see cref="ITreeChildren{T, TKey}"/>, its
        /// <c>Children</c> is overwritten with the nodes whose ParentId equals its Id, in list order.
        /// Nodes that do not implement <see cref="ITreeChildren{T, TKey}"/> are returned unchanged.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="treeDataList"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">The data contains a circular parent reference.</exception>
        public static List<T> ToTree<T, TKey>(List<T> treeDataList, TKey parentId = default) where T : ITreeParentId<T, TKey>
        {
            if (treeDataList == null) throw new ArgumentNullException(nameof(treeDataList));

            // Try-pattern: a plain "!= null" check on the returned node is always true for value-type T.
            if (TryFindFirstCircularReferenceNode<T, TKey>(treeDataList, out var circularReferenceNode))
            {
                throw new ArgumentException($"数据源中存在循环引用项：Id={circularReferenceNode.Id}");
            }

            var comparer = EqualityComparer<TKey>.Default;
            // Bucket nodes by ParentId once (Lookup accepts null keys and keeps list order within a bucket)
            // instead of rescanning the whole list for every node.
            var childrenByParentId = treeDataList
                .Where(x => x != null)
                .ToLookup(x => x.ParentId, comparer);

            var roots = childrenByParentId[parentId].ToList();
            // Ids on the current root-to-node path; guards against loops created by duplicate Ids.
            var ancestorIds = new HashSet<TKey>(comparer);
            foreach (var root in roots)
            {
                LinkChildren(root);
            }
            return roots;

            void LinkChildren(T node)
            {
                if (!(node is ITreeChildren<T, TKey> withChildren)) return;

                if (!ancestorIds.Add(node.Id))
                {
                    throw new ArgumentException($"数据源中存在循环引用项：Id={node.Id}");
                }

                // A null Id cannot be anyone's parent.
                var children = node.Id == null ? new List<T>() : childrenByParentId[node.Id].ToList();
                foreach (var child in children)
                {
                    LinkChildren(child);
                }
                withChildren.Children = children;

                ancestorIds.Remove(node.Id);
            }
        }
        #endregion

        #region Circular reference detection
        /// <summary>
        /// Returns the first node (in list order) whose chain of ParentId links leads back to itself.
        /// </summary>
        /// <typeparam name="T">Node type.</typeparam>
        /// <typeparam name="TKey">Node identifier type.</typeparam>
        /// <param name="nodes">Flat node list; <c>null</c> entries and nodes with a <c>null</c> Id are skipped.</param>
        /// <returns>The first node on a cycle, or <c>default(T)</c> when the data is acyclic.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="nodes"/> is <c>null</c>.</exception>
        public static T GetFirstCircularReferenceNode<T, TKey>(List<T> nodes) where T : ITreeParentId<T, TKey>
        {
            TryFindFirstCircularReferenceNode<T, TKey>(nodes, out var node);
            return node;
        }

        /// <summary>Try-variant of <see cref="GetFirstCircularReferenceNode{T, TKey}"/>, usable for value-type nodes.</summary>
        private static bool TryFindFirstCircularReferenceNode<T, TKey>(List<T> nodes, out T circularNode) where T : ITreeParentId<T, TKey>
        {
            if (nodes == null) throw new ArgumentNullException(nameof(nodes));

            var comparer = EqualityComparer<TKey>.Default;

            // Id -> node index built once; the first node with a given Id wins, matching List<T>.Find.
            var nodesById = new Dictionary<TKey, T>(comparer);
            foreach (var node in nodes)
            {
                if (node != null && node.Id != null)
                {
                    nodesById.TryAdd(node.Id, node);
                }
            }

            foreach (var start in nodes)
            {
                if (start == null || start.Id == null) continue;

                // Ids seen on THIS walk only, so earlier walks cannot cause false positives.
                var seen = new HashSet<TKey>(comparer) { start.Id };
                var current = start;
                while (true)
                {
                    var parentId = current.ParentId;
                    if (parentId == null) break; // reached a root

                    if (comparer.Equals(parentId, start.Id))
                    {
                        circularNode = start;
                        return true;
                    }

                    // Entered a loop that does not include start; its members are reported on their own walk.
                    if (!seen.Add(parentId)) break;

                    // Parent not present in the list: the chain ends here.
                    if (!nodesById.TryGetValue(parentId, out current)) break;
                }
            }

            circularNode = default;
            return false;
        }
        #endregion

        #region ToTree (selector based)
        /// <summary>
        /// Converts a flat collection into a tree, using selectors for the id and parent-id fields.
        /// </summary>
        /// <typeparam name="T">Node type.</typeparam>
        /// <typeparam name="TKey">Identifier type.</typeparam>
        /// <param name="source">Flat nodes; <c>null</c> entries are ignored. An <see cref="IQueryable{T}"/> is executed once.</param>
        /// <param name="idSelector">Selects a node's id.</param>
        /// <param name="pidSelector">Selects a node's parent id; must not be the same member as <paramref name="idSelector"/>.</param>
        /// <param name="topValue">Parent-id value of root nodes. Nodes whose parent id is <c>null</c> are roots as well.</param>
        /// <returns>The root nodes in source order, with <c>Children</c> populated recursively.</returns>
        /// <remarks>
        /// Mutates the nodes: every node reached gets a new <c>Children</c> list, and each root's
        /// <c>ParentId</c> is reset to <c>default(TKey)</c>.
        /// </remarks>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">Both selectors select the same member, or the data contains a cycle.</exception>
        public static List<T> ToTree<T, TKey>(this IEnumerable<T> source, Expression<Func<T, TKey>> idSelector, Expression<Func<T, TKey>> pidSelector, TKey topValue = default)
            where T : ITreeParentId<T, TKey>, ITreeChildren<T, TKey>
            where TKey : IComparable
        {
            if (source == null) throw new ArgumentNullException(nameof(source));
            if (idSelector == null) throw new ArgumentNullException(nameof(idSelector));
            if (pidSelector == null) throw new ArgumentNullException(nameof(pidSelector));

            if (idSelector.Body.ToString() == pidSelector.Body.ToString())
            {
                throw new ArgumentException("idSelector和pidSelector不应该为同一字段！");
            }

            var idFunc = idSelector.Compile();
            var pidFunc = pidSelector.Compile();

            // Materialize once (this also executes an IQueryable) instead of re-enumerating the source per node.
            var nodes = source.Where(t => t != null).ToList();
            // Children grouped by parent id; Lookup keeps source order inside each group.
            var childrenByPid = nodes.ToLookup(pidFunc);

            var roots = new List<T>();
            foreach (var node in nodes)
            {
                var pid = pidFunc(node);
                if (!(pid is null || pid.Equals(topValue))) continue;

                var root = node;
                root.ParentId = default;
                TransData(childrenByPid, root, idFunc, new HashSet<TKey>());
                roots.Add(root);
            }

            return roots;
        }

        /// <summary>
        /// Recursively assigns <c>Children</c> to <paramref name="parent"/> and all of its descendants.
        /// </summary>
        /// <param name="childrenByPid">Nodes grouped by parent id.</param>
        /// <param name="parent">Node whose children are assigned.</param>
        /// <param name="idSelector">Selects a node's id.</param>
        /// <param name="ancestorIds">Ids on the current root-to-parent path, used to detect cycles.</param>
        /// <exception cref="ArgumentException">A node's id already occurs among its ancestors (infinite tree).</exception>
        private static void TransData<T, TKey>(ILookup<TKey, T> childrenByPid, T parent, Func<T, TKey> idSelector, HashSet<TKey> ancestorIds)
            where T : ITreeChildren<T, TKey>
            where TKey : IComparable
        {
            var parentId = idSelector(parent);
            if (parentId is null)
            {
                // A null id matches no node's parent id.
                parent.Children = new List<T>();
                return;
            }

            if (!ancestorIds.Add(parentId))
            {
                throw new ArgumentException($"数据源中存在循环引用项：Id={parentId}");
            }

            var children = childrenByPid[parentId].ToList();
            foreach (var child in children)
            {
                TransData(childrenByPid, child, idSelector, ancestorIds);
            }
            parent.Children = children;

            ancestorIds.Remove(parentId);
        }
        #endregion
    }
}
